tensorflow中tf.concat的axis的使用我一直理解的比较模糊,这次做个笔记理下自己的思路。
1 | import tensorflow as tf |
先生成两个矩阵m1, 和m2, 大小为两行三列
1 | m1 = np.random.rand(2,3) # m1.shape (2,3) |
接下来采用tf.concat进行连接,简单来说,axis=0实际就是按行拼接,axis=1就是按列拼接
1 | # axis = 0 |
但这实际上这只有在我们的输入是二维矩阵时才可以这样理解。axis的实际含义是根据axis指定的维度进行连接,如矩阵m1的维度为(2,3), 那么axis=0就代表了第一个维度‘2’,因此,将m1和m2按照第一个维度进行连接,得到的新的矩阵就是将第一维度进行相加,其余维度不变,即维度变成了(4,3).
同理,axis=1时就是将矩阵的第二维度进行合并,其余维度不变,即维度变成了(2,6)。
接下来处理三个维度的数据,这也是我们在神经网络数据中经常要用到的,增加的一个维度通常代表了batch_size. 如下面的m5, batch_size=5, 可以理解为每个样本是个2*3的矩阵,一次将5个样本放在一起。
1 | m5 = np.random.rand(5,2,3) |
在这种情况下,axis=0代表的第一个维度的含义就不再是之前认为的行的概念了,现在m5的第一维度的值是5,代表的是batch_size。仍然按照之前的理解,如果设置axis=0, axis=0就是将第一维度进行相加,其余维度不变,因此我们可以得到新的维度为(10,2,3)。
1 | m7 = tf.concat([m5, m6],axis=0) |
同理,也可以进行axis=1, axis=2的concat操作。
此外,axis的值也可以设置为负数,如axis=-1实际上就是指倒数第一个维度,如m5的倒数第一个维度的值就是‘3’。因此,axis=2的操作和axis=-1的操作是等价的。